import os
import json
import pandas as pd

import torch
from PIL import Image
from torchvision import transforms

from model import resnet34


def main():
    """Batch-predict classes for every jpg image in a folder and save a CSV.

    Loads a trained resnet34 (5 classes), runs inference in batches over all
    ``.jpg``/``.jpeg`` files under ``imgs_root``, prints each prediction, and
    writes the full result table to ``predict_result.csv`` with columns
    ``image_name, class, prob, image_path``.

    Raises:
        FileNotFoundError: if ``imgs_root`` or ``weights_path`` is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Paths to adjust before running:
    imgs_root = "../../data_set/flower_data/flower_photos/roses"  # folder of images to predict
    weights_path = "./weights/model-0.pth"  # trained weights (must match the model architecture)

    # 1. Preprocessing pipeline (standard ImageNet mean/std normalization).
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Validate inputs with explicit exceptions (assert is stripped under -O).
    if not os.path.exists(imgs_root):
        raise FileNotFoundError(f"file: '{imgs_root}' does not exist.")

    # Collect every jpg path and its bare file name (without extension).
    img_path_list = []
    img_names_list = []
    for fname in os.listdir(imgs_root):
        if fname.lower().endswith((".jpg", ".jpeg")):
            img_path_list.append(os.path.join(imgs_root, fname))
            img_names_list.append(os.path.splitext(fname)[0])

    # 2. Create the model and load the trained weights.
    model = resnet34(num_classes=5).to(device)
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"file: '{weights_path}' does not exist.")
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # Optional index->label mapping; fall back to raw class indices if absent.
    json_path = './class_indices.json'
    class_indict = {}
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            class_indict = json.load(f)
    else:
        print(f"file: '{json_path}' does not exist.")

    # 3. Prediction.
    model.eval()
    batch_size = 8  # how many images are packed into one forward pass
    colnames = ['image_name', 'class', 'prob', 'image_path']
    batch_frames = []  # per-batch DataFrames, concatenated once at the end

    with torch.no_grad():
        # Ceil-divide so the final partial batch is processed too
        # (floor division would silently drop up to batch_size-1 images).
        num_batches = (len(img_path_list) + batch_size - 1) // batch_size
        for ids in range(num_batches):
            img_list = []
            for img_path in img_path_list[ids * batch_size: (ids + 1) * batch_size]:
                # convert("RGB") guards against grayscale/RGBA inputs, which
                # would break Normalize (it expects exactly 3 channels).
                img = Image.open(img_path).convert("RGB")
                img_list.append(data_transform(img))

            # Stack the batch into one (B, 3, 224, 224) tensor and predict.
            batch_img = torch.stack(img_list, dim=0)
            output = model(batch_img.to(device)).cpu()
            predict = torch.softmax(output, dim=1)
            probs, classes = torch.max(predict, dim=1)

            rows = []
            for idx, (pro, cla) in enumerate(zip(probs, classes)):
                cls_key = str(cla.item())
                # Map through the JSON label dict when available; otherwise
                # report the numeric class index as-is.
                cls_label = class_indict.get(cls_key, cls_key) if class_indict else cls_key
                rows.append({
                    'image_name': img_names_list[ids * batch_size + idx],
                    'class': cls_label,
                    'prob': pro.item(),
                    'image_path': img_path_list[ids * batch_size + idx],
                })
                print("image: {}  class: {}  prob: {:.3}".format(
                    rows[-1]['image_name'], cls_label, pro.item()))

            batch_frames.append(pd.DataFrame(rows, columns=colnames))

    # 4. Save once, with the header intact. DataFrame.append was removed in
    # pandas 2.0, so per-batch frames are concatenated instead; this also
    # avoids the quadratic copy cost of repeated appends.
    result = (pd.concat(batch_frames, ignore_index=True)
              if batch_frames else pd.DataFrame(columns=colnames))
    result.to_csv("predict_result.csv", index=False)
            

# Script entry point: run batch prediction when executed directly.
if __name__ == '__main__':
    main()
